{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b96d9441",
   "metadata": {},
   "source": [
    "## Graph Convolution LSTM\n",
    "Standard LSTM cell with Chebyshev graph convolution. <br />\n",
    "Very similar to the GCLSTM, but seems to perform better. <br />\n",
    "\n",
    "Usage is analogous to a standard pytorch LSTMCell. <br />"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1963b49",
   "metadata": {},
   "source": [
    "### GConvLSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7e473b3a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [16:52<00:00,  5.06s/it]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA1n0lEQVR4nO3deXycZb338c9vZjKTbZI0yWRfu6Ql3du0ZSubUChC2cQDisLjgvqAirgcPCoiR32U4+EoinoQET1HLIsoRQsFyl6gNN33NE3SJM2+78sk1/PHTKZJkzRpO8lk+b1fr7w6c9/3zPwySb9z5bqv+7rEGINSSqnJzxLoApRSSvmHBrpSSk0RGuhKKTVFaKArpdQUoYGulFJThC1QLxwbG2syMjIC9fJKKTUpbd++vcYY4xpqX8ACPSMjg9zc3EC9vFJKTUoicmy4fdrlopRSU4QGulJKTREa6EopNUVooCul1BShga6UUlOEBrpSSk0RGuhKKTVFTLpA31XSwE9fPhToMpRSasKZdIG+t7SB37x5lP1ljYEuRSmlJpRJF+jXLk7CbrXw3PbSQJeilFITyqQL9KhQO1fMj+eFXWV0uXsBeP9oLUerWwJcmVJKBdakC3SAjy1Poa61i9cPVdHba/ji/27n4VfyAl2WUkoF1KQM9NWzY4kNt/PyvnKKaltpbO+msKY10GUppVRATcpAt1ktrJoZw9bCOnaVNABwrLYVXfBaKTWdTcpABzh3ZgzljR38Y085AK1dPdS2dgW4KqWUCpzJG+iZ0QC8fqiKIKsAcKy2LZAlKaVUQE3aQJ8dF05MmB2Ai7M8i3ccq9V+dKXU9DVpA11EWOltpV+zKAkRbaErpaa3SRvoABdlubBZhFUzo0mKDKG4TgNdKTV9BWxNUX/4l5xULpgVS2JkCGnRoRRpl4tSahqb1C10i0VIiwkFICM2lOKTulzeP1rL1oLaQJSmlFLjblSBLiJXichhEckXkfuG2P9fIrLL+5UnIg1+r3QEadFh1LZ20dzRDUBPr+Gr63fygxcPjHcpSikVECMGuohYgUeBtUA2cKuIZPc/xhjzNWPMEmPMEuCXwPNjUOspLUiOAODXbx4FYEt+DVXNneRXt+Du6R3vcpRSatyNpg99JZBvjCkAEJH1wHXAcE3fW4Hv+6e80btwdiyfWJXGb948Slp0qK+rpcvdy7G6Nma5wse7JKWUGlej6XJJBkr63S/1bhtERNKBTOD1YfbfKSK5IpJbXV19urWekojwg3XzWT0nlm8/v5cX95SzODUKgMMVzX59LaWUmoj8fVL0FuA5Y0zPUDuNMY8ZY3KMMTkul8vPLw1BVgtP3LGCO87PwCLw3Y+eg0U00JVS08NoulyOA6n97qd4tw3lFuCusy3qbARZLTywbj73rZ1HcJCVjJgw8io10JVSU99oWujbgDkikikidjyhveHkg0RkHjADeN+/JZ6Z4CArAFnxTg5roCulpoERA90Y4wbuBjYBB4FnjDH7ReRBEVnX79BbgPVmgs1hm5XgpKimlY7uwb1AzR3dOuWuUmrKGFUfujFmozEmyxgzyxjzI++2+40xG/od84AxZtAY9UCbG++k10B+1cAl6qqbO1nxo9fYtL8yQJUppZR/TeorRUdjXqITgIPlTQO27yppoKO7V9ciVUpNGVM+0DNjwgizW9l3vHHA9r3e+7UtuiiGUmpqmPKBbrEI85MifQHepy/g61o7A1GWUkr53ZQPdIAFyZEcKG+ip/fECVBfC12XrVNKTRHTJNAjBvSXVzZ1UN3saZlrl4tSaqqYFoG+MDkSgL2ljQP+neUKo1a7XJRSU8S0CPSZrnBCgqy+bpa9xxuxCKye46KutUvHoiulpoRpEehWi5CdFMG7+TWUNbTz1IfFLE6NImVGCN09huZOd6BLVEqpszYtAh3gzotmcrS6hSsefouGti5+dP1CosPsgPajK6WmhmkT6FfOT+BH1y+ktauHr6+ZS3ZSBDHhDkCHLiqlpoZJvUj06frEqjTWzI8n1hvkMd4Weo220JVSU8C0aaH36QtzgJhwT6DX6Vh0pdQUMO0Cvb8Tfeja5aKUmvymdaA7bFacDpteLaqUmhKmdaADRIfbqWvtYndJA21dOnxRKTV5aaCH2Xn/aC3XPbqFR9/IH7T/t28d5b2jNQGoTCmlTs+0D/SYMAdV3nldNuwuG3DVaFlDOz956RD3rN9Fi158pJSa4KZ9oMd6R7qsW5xESV07u0tPTLO7cW85AFXNnTyy+UhA6lNKqdGa9oF+27np/OiGBfz79QuwWy38Y3eZb9/GveWckxjBx3NSeOLdQh0No5Sa0KbVhUVDWZAcyQLvbIwXZbl4dnspLqeDeYkR7Chu4BtrslieHs0zuaXsPd7IJXPjAlyxUkoNbdq30Pv7+posMmPD+H8vHeL2Jz4EYO3CRLKTIgDYX9Z0qocrpVRATfsWen/nJEbw97suoLS+jX3HmxCBWa5wAFKjQzigga6UmsBGFegichXwC8AKPG6M+ckQx3wceAAwwG5jzCf8WOe4SpkRSsqM0AHbFiRFsq+scZhHKKVU4I3Y5SIiVuBRYC2QDdwqItknHTMH+DZwgTFmPnCP/0sNrPlJERyrbaOpozvQpSil1JBG04e+Esg3xhQYY7qA9cB1Jx3zeeBRY0w9gDGmyr9lBt78JM+J04Pa7aKUmqBG0+WSDJT0u18KrDrpmCwAEdmCp1vmAWPMyyc/kYjcCdwJkJaWdib1Bsz8ZM+J0b/tPM6HhXUcrGjikrlxfDwnNcCVKaWUh79OitqAOcAlQArwtogsNMY09D/IGPMY8BhATk7OpFrIM84ZTGJkMOu3eT7bwuxW3jlSw0cXJhLm0HPLSqnAG00SHQf6N0NTvNv6KwW2GmO6gUIRycMT8Nv8UuUE8T+fXUldazfnJDrJq2zhpt+8x3PbS7n9/IxAl6aUUqPqQ98GzBGRTBGxA7cAG0465u94WueISCyeLpgC/5U5McyOc7IyMxpncBDL02ewJDWKP2wppLd3Uv2xoZSaokYMdGOMG7gb2AQcBJ4xxuwXkQdFZJ33sE1ArYgcAN4AvmmMqR2roieKO87PoKi2jV2lDYEuRSmlRteHbozZCGw8adv9/W4b4F7v17SxKMUz8qWoppVlaTMCXI1SarrTS//PQvKMEESguK4t0KUopZQG+tlw2KwkRgRroCulJgQN9LOUGh1KSV0bRTWtLHpgE/uO6/QASqnA0EA/S2nRoRTXtfFufg1NHW5ePzTlLpJVSk0SGuhnKS06lMqmTrYW1gHwofffkxVUt/DK/orxLE0pNc1ooJ+ltBjPrIybD1YCsP1YPd09vYOOe+ztAu56agcd3T3jWp9SavrQQD9LqdGeQG/r6mFuvJP27p4hF8Ioa+ygu8foIhlKqTGjgX6W0qJPzJv+2dWZAHxYOPiaqvKGdgB2FtePT2FKqWlHA/0sxYTZCbVbAbj8nHhmxoYN2Y9e0dgBwM6ShvEsTyk1jWignyURIS06lLToUKLD7CxJi2LvSUMXmzu6ae50A7CruGHI5zlY3kRda9dYl6uUmsI00P3gCxfP5J7L5wAwN95JZVMnjW0nVjbqa50vTo3ieEM7VU0dAx7v7unl4799n0c2Hxm/opVSU44Guh/csDSFG5elAJCV4AQgr6rZt7/cG+hXL0gABne75Fe30NzppkSvOFVKnQUNdD/LivcE+uGKE4He10L/yDlxAOT12wewp8TTRdMX/EopdSY00P0sKTKYcIeNvMoToV3W6BnhkhYdRmRIEDUtnQMes+d4AwCVTRroSqkzp4HuZyJCVnz4oBZ6bLgDu81CbLid6pMDvdTTQq9t7fJdeLSrpIEVP3ptUPgrpdRwNNDHwNwEJ3mVzXimifd0pSRGBgMQG+6gpvnEaJZOdw8Hy5uIDXcAUNXkCfA9pQ1UN3dyqLwZpZQaDQ30MTAnzkl9Wzc1LZ7gLm9sJ8Eb6C6nY0AL/XBFM909hiuy433HAtQ0e44prdcTpUqp0dFAHwNzvSNd9pWdONmZNKCFfiLQ+7pb1ngDvcLbj17t/TAo0UBXSo2SBvoYWJwaRWy4nYdfyaOsoZ3mDjcJkSGAp4Xe3On29ZXnV7UQZreSk+FZwq5vREy1r4XeHoDvQCk1GWmgj4Fwh40frFvA3uONXP7wWzhsFi6cHQuAy9tX3hfYBTWtZLrCcAYH4XTYfEMX+06GaqArpUZLA32MXL0wgWsXJxERHMQzXziPhd4FpV1Ob6B7A7uwpoXM2HAAEiKDfS30vkDXi42UUqNlC3QBU5WI8It/WUKvMdisJz43+0az1DR30unuobS+nRuWeq4yTYgMprypA2MM1c2dWASqmjvp6O4hOMgakO9DKTV5jKqFLiJXichhEckXkfuG2H+HiFSLyC7v1+f8X+rkY7HIgDCHgS304to2jIGZsWEAJEQEU9nYQUunm053r++q0+MN2u2ilBrZiIEuIlbgUWAtkA3cKiLZQxz6tDFmiffrcT/XOWXEhNsBqGnu4mh1KwAzXZ5AT4wMpqq5w3fF6NI0z4lS7UdXSo3GaFroK4F8Y0yBMaYLWA9cN7ZlTV1BVgszQoOobumgsMYT6BneFnpSVAi9BnZ553ZZmhoF6Fh0pdTojCbQk4GSfvdLvdtOdpOI7BGR50QkdagnEpE7RSRXRHKrq6vPoNypoe9q0cKaFmLDHUQEBwGwINlz4rRvfdLspAiCrEJJnbbQlVIj89colxeBDGPMIuBV4I9DHWSMecwYk2OMyXG5XH566cknNtxztWhhTauv/xxgXoKT4CALb+V5PuziI4JJjgrRFrpSalRGE+jHgf4t7hTvNh9jTK0xpu/yx8eB5f4pb2pyOR2U1LWRV9lCZr9At1ktLEqJoq2rB4tAdJiduIhgqpp1gi6l1MhGE+jbgDkikikiduAWYEP/A0Qksd/ddcBB/5U49bicDqq8wxY/uihxwL6laVEARIc5sFoEV7hDZ1xUSo3KiOPQjTFuEbkb2ARYgSeMMftF5EEg1xizAfiKiKwD3EAdcMcY1jzp/cuKVELtVj65Kt03aVefpamekS2x3tEwMeH2AXO/KKXUcEZ1YZExZiOw8aRt9/e7/W3g2/4tberKinfy9TVzh9zX10LvG68eG+6gqcNNp7sHh00vLlJKDU8v/Z9g4iOCmekK8/Wt911ZWtvSdaqHKaWUXvo/ET37hfMIsXta431dL7UtXSRFhQSyLKXUBKeBPgHFeFvlALHerhc9MaqUGol2uUxwvul2NdCVUiPQQJ/gfHO/aKArpUaggT7BhdpthNqtAxaWVkqpoWigTwKx4Q5qW7WFrpQ6NQ30SSA23K5dLkqpEWmgTwJ9szMqpdSpaKBPArFOnc9FKTUyDfRJIDbMTl1bF+6e3kCXopSawDTQJ4FYpwNjPAtGG2MCXY5SaoLSQJ8E+uZzWf3QG3xl/a7AFqOUmrA00CeBFRnRXDk/npQZIew/3hjocpRSE5QG+iTgcjr470/lcMU58ZQ1tmu3i1JqSBrok0hiVAgd3b00tHUHuhSl1ASkgT6JJHpXNyprbA9wJUqpiUgDfRLpC/Tyho4AV6KUmog00CeRvgUuyrWFrpQaggb6JBIb7sBmEcobT7TQD5Y30dHdE8CqlFIThQb6JGK1CPERwb5A31PawNpfvMPlD7/FW3nVAa5OKRVoGuiTTFJUMGUNni6Xf+4tx2YRrBbh357fG+DKlFKBNqpAF5GrROSwiOSLyH2nOO4mETEikuO/ElV/iZEhlDd2YIxh074KzpsVw0cXJlLZ1EFvr45PV2o6GzHQRcQKPAqsBbKBW0Uke4jjnMBXga3+LlKdkBgVTEVjB4cqmimqbePK+Qm4nA7cvYaGdh2frtR0NpoW+kog3xhTYIzpAtYD1w1x3L8DPwV0TN0YSowIpqunlye3FCECa7LjcTm9C0k36xS7Sk1nown0ZKCk3/1S7zYfEVkGpBpj/nmqJxKRO0UkV0Ryq6v1JN6ZSPQOXXw6t4RLslzERQTjCtdAV0qB7WyfQEQswMPAHSMda4x5DHgMICcnRzt8z8CilEhSo0O4YUkyd102G+BEC71F/zhSajobTaAfB1L73U/xbuvjBBYAb4oIQAKwQUTWGWNy/VWo8kiMDOGdb102YFtfoOsydUpNb6PpctkGzBGRTBGxA7cAG/p2GmMajTGxxpgMY0wG8AGgYT6Owh02goMsVOsydUpNayMGujHGDdwNbAIOAs8YY/aLyIMism6sC1QjExFcTof2oSs1zY2qD90YsxHYeNK2+4c59pKzL0udLle4BrpS051eKTpFaAtdKaWBPkW4nA7tQ1dqmtNAnyJc4cHUtXbR3dMb6FKUUgGigT5FxDrtANS26NBFpaYrDfQpQq8WVUppoE8RfRcX/fL1I/z8tbwAV6OUCgQN9CkiMdIzx8srByr5+WtHaGzTmReVmm400KeIhMhg/vy5VfznzYsB2F5chzEGY3TKHKWmCw30KeSC2bFcvTARm0XILarn9+8WsvqhN3D3G/lyoKyJulY9carUVKSBPsWE2K3MT45ka2EdT7xbSGl9O4cqmgEoqWvj+l9v8fWx7yltoKSuLZDlKqX8SAN9ClqRPoPtx+op8y4mnVtUB8BPXz5El7uXI5UtANz5p+385KVDAatTKeVfGuhTUE5GNADRYXbiIxzkHqtnR3E9/9hTjt1moai2lcb2biqaOjhc2RzgapVS/qKBPgXlZMzAahFuXp7Cioxoth+r5+evHSE23M5nLsikvLGD/WWNABTVtNLl7mXj3nK25NeM+NytnW6uePgt3j9aO9bfhlLqNGmgT0Gx4Q5euOsCvnZFFjnpMyhv7ODtvGo+e+FMspMiAHjjUBUA7l5DUW0r3/v7Ph7ZfGTE5z5c2cyRqpZRhb9Sanyd9RJ0amJakBwJnOh+iQi2cdu5aRTWtAKw+WCV79hN+yqobe0iuL59xOct8j6+73mUUhOHttCnuHkJTpKjQvjSJbNxBgeRERsGQEFNK2nRoYjA+m2eNcDLG9vpcp96cq+i2jbf45VSE4u20Kc4m9XCO9+6FM9yrxARHERMmJ3a1i7mJ0UgAse8Id1roKyh3Rf6QznRQm+ht9dgsciYfw9KqdHRFvo0YLEIIieCNz0mFIBZrnDmxIUDEBvuma2xpN4T7m1dbh5/p4C2LveA5yqq9QR6R3cvFU0dY167Umr0NNCnob4W+Ky4MGbHOQFYtzgZgJK6dowxfOu5PfzwnwfZuLfC9zhjDIU1rcz2fghoP7pSE4sG+jSUGeMNdFc45yR6Av26JUkEWYXiujb+sKWIf+wpB2D7sTrf4+pau2jucHPpXBeg/ehKTTTahz4NrZmfwOHKZuYmOMlOjCAxMoTFqVEkR4VQXNfK8ztKuXB2LDarsK2o3ve4vhOiqzJj+N8PiimobgnUt6CUGoK20KehuQlOfvWJZThsVmxWCyszPUMbU6NDefNwNVXNndyc47koKb+qhXrvZF59J0QzXWFkxoZpl4tSE8yoAl1ErhKRwyKSLyL3DbH/iyKyV0R2ici7IpLt/1LVWEuNDqWtqwe71cJl8+LISZ8BwPZjnlZ6UW0rFoHUGaFkusIoqNZAV2oiGTHQRcQKPAqsBbKBW4cI7KeMMQuNMUuAh4CH/V2oGntp0Z7RL6vnxOIMDmJxahRBVmGbtx+9oKaVlBmh2G0WFiZHUlzXRkWjjnRRaqIYTQt9JZBvjCkwxnQB64Hr+h9gjGnqdzcM0FUVJqG+QL9qQQIAwUFWFiZHkuvtRz9a1eIb4XLp3DgAXj9UNcQzKaUCYTSBngyU9Ltf6t02gIjcJSJH8bTQvzLUE4nInSKSKyK51dXVZ1KvGkMXZ7m494osrl2c5Nu2KCWKg+VNdPf0UtBvyGJWfDjJUSG+QDfGcNXP3+bJLYUBqV0p5ceTosaYR40xs4B/Bb47zDGPGWNyjDE5LpfLXy+t/CTMYeMrH5lDcJDVty07MYK2rh7eza+hy93LbJcn0EWEy+bFsSW/ho7uHiqbOjlU0cwjr+fT3tUTqG9BqWltNIF+HEjtdz/Fu20464Hrz6ImNYH0zc744u4ywHMxUp/L5sXR3t3DBwW1vhEvda1drN9WPP6FKqVGFejbgDkikikiduAWYEP/A0RkTr+7HwVGnodVTQqz48KxWYRX9ld67rucvn3nzYrBZhG2Ftb5pgTIjA3jsbcLBqxjqpQaHyMGujHGDdwNbAIOAs8YY/aLyIMiss572N0isl9EdgH3ArePVcFqfAUHWZkdF05Lp5vYcAeRoUGD9h0sb6KothW71cI3r5xLeWMH7+kCGEqNu1FdKWqM2QhsPGnb/f1uf9XPdakJ5JzECA5VNDM7bvAsjNlJEbxzpAaHzUJaTCiXzYvD6bDx4u4yLsrS8yRKjSe9UlSNKDvR04/eN8Ll5H3VzZ1sP9ZARkwYwUFW1sxP4OX9FXS69eSoUuNJA12NqO/EaN8Il6H21bR0khnrGce+bkkSzR1u3jysQ1OVGk8a6GpEy9NncNu5aVy1IHHQvvmJkb7b6d5ZHM+fFYPL6eDXb+SP+uTo23nVvLDrVIOnlFIj0UBXIwoOsvLD6xeSEBk8aF9kaBDJUSGAZ4QLQJDVwv3XZLO7tJEnhrjQqKalk8fePkpH94kumYdfzeO+v+6ltdM96Hil1OhooKuz1tft0n/pumsWJbImO57/fCWPmpZO3/bq5k5ufewDfrzxEE9t9YxXd/f0crC8ifbuHl45UIFS6sxooKuztnpOLEmRwSRGnGjBiwj3rsmi093LS3vLfdvveXonpfXtZMaG8cf3i+jtNRypaqHTuzj18zu020WpM6WBrs7ap85NZ8t9lw1aMHpuvJM5ceG86F39qKXTzQcFdXzmwgzuvSKLY7VtvJlXxd7SRgA+uiiRLfk1/H3ncZ7fUcrPNh3W2RyVOg26YpE6a/0XoD55+zWLkvj55jwqGjs4WN5ET6/h/FmxrMyMJiEimMffKWSWK5xwh42vX5HFK/sruOfpXb7ncPca7ls7b9jXbuty47BZsVqGrkGp6URb6GpMXbM4EWPgn3vL+aCgFrvVwvL0GQRZLXxudSbvHa3lH3vKmJ8UwUxXOLnfvYKX71nNq1+7iPNnxfDawcphn9sYw5U/f5tvPrd7HL8jpSYuDXQ1pma5wlmSGsXv3i7gzcPVLEmL8s3m+MlV6cSGO6hv62Zhsmf4Y2RIEPMSIpgT72RNdjz5VS3DLnVXUtdOSV07z+84zuZTBL9S04UGuhpz37smm4qmDg5XNnPezBjf9hC7lS9ePBOAhSmRgx73kXPiAXjtwNBhvbPEs/BGbLid7/xt34BhkP29vK/cN1ukUlOZBroac8vTZ/AvOZ4ZmM/tF+gAnzovnX+/bj5Xzk8Y9LjU6FDmJTh5cU8ZzR3dg/bvLG4gJMjKdz/q+cDIr2oZ8vU9Y9z30Ng++DmUmko00NW4+O415/DQTYtYlRk9YLvDZuVT52UMWFSjv0+em86e0kYueugNdhTXD9i3q6SBhSmRzE3wTOnbN4Vvf13uXgqqW2nt6uEvH+o87Wpq00BX48IZHMTHV6QOGto4kk+dm86Guy8A4A9binzbO909HChrYmlqFBneKQcKqwcHemFNK+5eQ3CQhSe3FNHl1nna1dSlga4mvEUpUazJTuDNQ1W+QD5Q1kRXTy9LUqMIsVtJjAymcIgW+uHKZgC+fNkcKpo6eOeIThimpi4NdDUpXJEdT3Onm62FnoUzdhQ3ALAkLQqAjJiwAaNh9pY2kl/VQl5FM1aLcNu56Vgtwk7v45SaivTCIjUpXDgnlpAgK68eqGT1HBfvHKkmMzaMxEjvxGCuMN8UA+1dPXz6ia24nA7SosPIjA0jMiSIufFOdpc2BPC7UGpsaQtdTQrBQVZWz4nllf2VtHW5+aCglov7rYiUGRNGfVs3DW1dPLu9hPq2bvIqW3g7r5q58Z6TpkvSothd0kBvrwnUt6HUmNJAV5PGjcuSqWjq4Psv7Keju5eL5/YLdO9Mj0erW3j8nUIWJkcSGRJEV08vWX2BnhJFU4d7yNEw6ux09/RijH5QBpoGupo01mQnMDsunGe3l+KwWQZcpNQ3de9PXzpMcV0b//eSWdy8PAWAuQmelZYWp0YBnuGOyn863T2s+vFmns0tDXQp054Gupo0LBbhrktnAbBqZsyAsetp0aFYBD4squPyc+K5cn4Cn79oJtcvSeK8WbGAZ03UMLuVDwvrqG7uHPI1pqK61i4+LKwbs+evauqkrrXrlPPuqPGhga4mlWsXJbEmO57bVqUN2G63WciKd7I0LYpf3roUi0WIjwjm57csJTIkCACrRViUEsX6bSWs+NFrvLL/9BbTOFrdwn1/3TPpxrI/+V4Rt/7ugyGvtvWHKu+HY+6xeu12CbBRBbqIXCUih0UkX0TuG2L/vSJyQET2iMhmEUn3f6lKgc1q4bFP57BmiKkCnv3ieTx953mE2Ie+6hTgB9fN54Frs4kNt/PCrtOb3+WZbSWs31bC3uMNp1t2QNW0dNLTa9jjnXfe3/r+2qlr7eLoEBd3qfEzYqCLiBV4FFgLZAO3ikj2SYftBHKMMYuA54CH/F2oUiNxBgdht536Vzor3skdF2SydkEirx+qor1r8IRe6z8s5vKH3xo02dcHBZ4x8LtKxiYYx0pjm6dlvvOkqRP8pbr5xCIkuUVj17WjRjaaFvpKIN8YU2CM6QLWA9f1P8AY84Yxps179wMgxb9lKuVfaxck0N7dw1t5VYP2vbCrjPyqFv6288RyeM0d3ewrawIm30nVhvYuwDOZWXtXz6A5cc5WdXMnIhAdZmdb0dh8aKjRGU2gJwMl/e6XercN57PAS0PtEJE7RSRXRHKrq/USbBU4KzOjiQ6z82xu6YCWeHtXD9uPeULpiXcLfX3Cucfq6ek1xIY72D3ZAr2vhV7SwIP/2M/HfvMe9a1dfnv+quZOYsIcrMiYwTZtoQeUX0+KishtQA7wH0PtN8Y8ZozJMcbkuFyuoQ5RalzYrBZuzklh86Eqzvt/m/nje54Fq7cW1tLV08sNS5M5UtXC20dqANhaUEeQVbjt3DSK69qobZk8o2Qa2roJsgp1rV2s31ZCrzkxx40/VDd3Eud0sDg1iuK6tjE7+apGNppAPw6k9ruf4t02gIhcDnwHWGeMmTy/7Wrauu+qeTz1uVUsSI7k+xv2c9vvt7JpfyV2q4UH1s0nzung8XcKAHi/oJbFKVG++dz/sKWIbz+/l9ZON41t3fxs0+EJG/KN7d2syPBMWxxk8fyXP+LPQG/pxOV0kOmd9fJYbdsIj1BjZTSBvg2YIyKZImIHbgE29D9ARJYC/40nzAd3Sio1AYkI58+O5U+fWclPblzI1sI6/vJhMTkZM4gMCeLT56XzzpEantxSyO6SBj5yTjwLkyOxCPzqjXz+8mExL+wq40/vF/GrN/L5/J9yh101KVC6e3pp6XSzIiOa5KgQ7r5sNs5gG3mVQy8GciaqmjyB3ndxl16JGzgjBroxxg3cDWwCDgLPGGP2i8iDIrLOe9h/AOHAsyKyS0Q2DPN0Sk04IsItK9P41a1LsVmEK7I9S999YlU6DpuFB148QHpMKP/nggzCHDbWLU7ihqXJzI4L5+ncEp7ZXkJSZDA7ihv45nN7hhyL3eXu5Q9bCimo9l+QjkbfKk0x4Xbe/talfPmy2WTFO8nzUwu9t9dQ0+LpckmPCQWgaJg1YNXYG9Vsi8aYjcDGk7bd3+/25X6uS6lxt3ZhIrmzYnwXIkWH2blpeQpPbS3mwesW+K5M/fktSwH43dsF/GjjQQB+ccsSjje089DLh8mMDePeK7Jo6XTzvb/vY0VGNG8eruKVA5WEBFn58Y0LuGHpyAPB3sqrJtxhY3n6jDP+nvpOiEaGBGH1Li6SFR/Oy/sqMMYgcnoLjpysvq0Ld6/B5XQQarcRH+GgsEa7XAJFp89Vqp+oUPuA+/965TzWZMcPmNmxz/VLk/nJy4cIs1u5cn4CDpuFwupWHtl8hBUZMzhS6Rn62Df88Rtrsngrr5pvPbeHy+bGc6yulQdfPMChimauWZTIg9ct8I2j/5/3i/jeC/sBuGZRIg9/fMmIY+yH0ugdstj/+5oT5+QvbSXUtHThcjpO+zn7q/aeN4hzBgOeeemPaZdLwGigK3UKkaFBXDI3bsh9LqeDuy+dTUy43dd6/+ENC3i/oJaHXj5MS6ebpWlRfHPNXDrcPVw2L54L57i4/tEtvHqwkmdzSyioaeXiLBfrt5VQ1tjBk3es4K28ar73wn4uPyeOma5wHnu7gI8tT+GSuXE0tHUN+tA5lb4WepT3rw7AN/vkBwW1hAfbuHSY7280qpo6fe8FeAJ98yGd0yVQNNCVOgtfuyJrwH2Hzco9l2fxjWd3A56umPNnx/r2L06JJDkqhMfePkpeZQvfumou//eS2Sx9p4Af/vMg7xfU8tSHxcRHOPj1J5fT02v4w5ZC3i+oxW61cNvvt/LsF88fdTeML9BD+we6Z/bJr6zfiTHwjy9fyILkyDP6/vsu+4/rC/TYMGpaumju6MYZHHSqh6oxoJNzKeVnNyxNZpYrjNhwB2sXJA7YJyJ8dFEieZUt2CzCx7xT/N52bjrOYBtPvFvIm4eruHZREnabhRC7lSWpUbx/tJbndx6n18DT24pHXUtD+4k+9D4up4OUGSGkR4ciAq8eOPMWdd/EXH0t9MxYz4lRfwxdNMbQo4uRnBYNdKX8zGoRnvw/K3nq86uG7Pf+6EJPyK+ZH+/rew4OsnLt4iQ2H6qiu8ewbkmS7/jzZsaw73gjm/ZVIAIb91YMOQfNUBrbuhBhQGtZRPjnl1ez6WsXsSxtxml1kXT39PLV9Tu5/YkPAahobCfMbiXM4fljP907Fr3QDyNdfvryYdb96t1Bo4Yef6eAy3725qD3wBhDSd30PiGrga7UGEiNDvX1VZ9sUUok37xyLl9fM3fA9puWeVrrmbFhLOzXBXLurBh6DTR3urlz9UxaOt28cmB0U/82tHcTEXxihEufyNAgHDYrl58Tz77jTZQ3tg/YP9TVnj29hnue3sULu8p4K6+a6uZOdhQ3DOiu6Ru66I8To7tK6tlf1sShihNDLEvr2/jZK4cpqGnlmdySAcc/t72U1Q+9wdef2T1tr1bVQFdqnIkId106m1mu8AHbl6VFcUV2PHdeNHPAcMJlaTOw2yxEBNu4d00WyVEhrP/wRJiV1rdx7S/fZd/xwbNANrZ3D+g/P9nl53hOiG4+eOJ6wO3H6ljy4KtsPzZwXpZHNh/hn3vKuWGpZyqn1w9Vsr+skVWZ0b5j/Dl0saTO8yHz8r4TH14/9g4TnZfg5HfvFODuOTE3/XtHa3HYLPxtZynf37D/rF9/MtJAV2qCEBF+9+kcbl05cPGO4CArn1iZxudWz8Rhs3L7+em8X1DrC9wfvHiAvccbeWWIvvCGtu4BI1xONjsunIyYUDbuLfdte+NQNT29hifeLfJte/dIDY+8foQblyXz0McWEWq38t9vFdBrYGVmzIDnPNXQxeaObp7cUkh3z6kXCel091Dm/athk3chkpK6NjbureDO1TP5xpq5lNa3889+de8orufSuXHcsDSFzQerBoT9dKGBrtQk8MC6+XzlI3MAzwnUmDA7//XqEZ7ZVsKrByoRGXpa34b2biJPMcxRRLhxWQrvHa319T/3zfu+aX8FFY0dVDV1cM/TO5ntCueH1y8gyGphefoMCmpasVmEZelRA54zMzZs2Mv/n95WwgMvHmDDrjK6e3p5/VAlvUOc+Dxe344xsCA5gkMVzRTVtLLZu8TdDctSuGxeHJmxYfzlQ88J4urmTo7VtrEsPYpL57lobO+edNMc+4MGulKTTKjdxp0XzeTd/Bq+9dc9zEtwcuPSFHaXNAw6gdjY1nXKFjrATctTEPH0Qbd39bC7tIG1CxLoMYYHNuznS3/eQWtnD7/+5DJC7Z6Tn32TlC1MifRt65Mec2Lo4sneyvNMm/37dwv5r1fz+MyTubx+aPD0T8XeD5fPr56JCPxlWzGvHaxidlw4mbFhWCzCTcuS+aCgjpK6Nt8c78vTZ7B6tgurRXjzsOe1jDHsKK7nQFnTqE8mj4WNe8v52abDY/oaOg5dqUno9vMzcPca5sSFs3qOixd2HeevO0oprGllZr+++YYR+tABkqNCuHB2LM9tL2VFRjTdPYabc1Kw2yy8sKsMq0X42c2LmNPvJG9fv/mqk7pbYODQxf4nTNu7ethaWEdSZDAHyps4UO5ZMGTzoSou986f06cv0M+bGcP1S5L543tFuHsMn1s903fMDctS+M9X83h+x3HautzYrRbmJ0USHGRlWVoUbxyu4htXzuWNw1V85slcwDO88sF181m7cOBw0vHwxLuFbC+u57Zz00mIDB6T19AWulKTUHCQlbsunc2a+QmE2K0sTfNcaNTXzeDu6WVLfo3npOgILXSAT65K43hDO197ZhcWgZyMaH5xy1IO/ftV7Lz/ikFzzyxJjeKzF2Zyy4rUQc813NDFrYW1dLl7+f66+cSE2YmPcHDB7BjePFw16C+L4to2goMsuJwO7rl8Du4eg7vXcEX2iatak6NCuGBWLH/eeoyN+8pZkBzhu2L3krlx7C9roqKxg1cPVBHusPHIrUuJczr40p93sNXbrTReOrp72FPaiDHw0r7ykR9whjTQlZoCZseFE2a38sr+Sr79/B5yfvQan3x8K8E2K8szokd8/JXzE/j6FVlUN3eSnRRBhHfcenCQ1Xe7P5vVwveuyfZNmdtfhm9e9IGB/nZeDQ6bhYuzXDz9hXN57ovnc93iZMobOzhYPnD2x+K6NtKiQxER0mPCuP38DFKjQ1iSOvAK2c+tzsQiQllDB1dkn1g4fO0Cz+3nd5by5uEqLpwdy7rFSTz9hfOwiGd++7Hk7unlqa3Fvi6evccb6erpxWaRASeg/U27XJSaAqwWYVFKFC/vryA4yMJV8xO4akEiF2e5CLFbR3y8iPDlj8xhTnz4WU/YFWK3khARTGFNG89tLyUkyEp8hIMXdh1n1cwYgoOszI7zdN84gjxtyjcOV5GdFOF7jr5A7/Odq8/hX6+aN2g8/SVz4/jg3z4yaObIma5wVmZG89s3j9LU4eaeyz2Tq4U7bMyJc475CdN3jtTwb3/bS2N7N1+6ZBa53rVWbzs3nSffK6KisWNMul000JWaIr50ySwWpUbymQsyiY84s7C4aoF/+pbTY0J5/VAlf91R6tuWMiOE71x9zoDj4pzBLE6J5MXdZXzp4llYLIIxhuK6Ns6bdaJ/3mIR7Jbhp/odahrgW1akcu8znjl1Ls460VWzJDWKVw74Z/rg4ez0nqT9n/eL+PzqTHKL6pjpCuNT53kCfePecj5zYabfX1e7XJSaIi7KcvHtteeccZj7U2ZsGPVt3cxyhfHb25bxtcuz+OeXVzM3YfDVs7efn8GhimZeOVDJm4er+NZze2jr6hnQQj8Taxck4gy2MS/BOaA1vCQtivq27tOab6aju4e9pY0YY2jrcnOgrAljDL29hrKG9kHH7yxpwG6zUNbYwd93lZF7rJ4V6dHMcoXz608u48ZlyWf1vQ1HW+hKKb+bm+DEIvDQxxazPH0GVy0Y/th1i5P45ev5fO+FfdS2dOIMDiI7MYIL+81SeSZC7FYe/cQywoMHxtyS1CgANu4r58PCOq5ekMjNOSnDttY73T187o+5vJtfw+y4cKqaOmjqcLM4JZLmTjcF1a08dNMiMmLDeOztozx43QJ2FTdw07Jk3s2v8c28uWqm51zG1WM4wkaGWi5rPOTk5Jjc3NyAvLZSamx1unsoa+ggc4iTpkP5285Svvb0bi7OcvHb25aPqt//TPX0GhY+sIm2rh5EwBi4cn48v71tOSJCa6ebu57awY3LUrh6QQJ3P7WTl/dX8JkLMtld2kBCRDBL06L43w+O4QwOwmGzsLOkAasIXT29XDA7hi35tfzs5sVkxobyQUEdc+LCuWxeHDbr2XeKiMh2Y0zOUPu0ha6U8juHzTrqMAe4fkkyyVGhLEmNOqOVmU6H1SIsTI4k91g9T9yxgu3H6nlk8xHeOVLDRVkufvVGPm8erua9/Fr+vvM4rx+q4nvXZPPZk/q8+8bEt3S6+dTvPSOKokKDeMk798zStChmucJZnj7yKCN/0UBXSgWciLAyc/yC7/5rs6lv7ebCObGcOzOav24v5b9eyyM6zM7j7xSwdkEC+8uaeP1QFV/5yJxBYd5fuMPG8186HxEhv6qZl/ZVEBkSRGbM6D/Q/EUDXSk17cxPOnEFq8Nm5e7LZvPt5/dyzS/fJSLYxg+um09ndy87iutZtzjpFM/k0df/PjvOySdWpRFkESynGJUzVjTQlVLT3seWp5Bf1UJCRDBXL0r0LTySegYjbX58w0J/lzdqo+qsEpGrROSwiOSLyH1D7L9IRHaIiFtEPub/MpVSauwEea98/fxFM0mOCgl0OWdsxEAXESvwKLAWyAZuFZHskw4rBu4AnvJ3gUoppUZnNF0uK4F8Y0wBgIisB64DDvQdYIwp8u6bfjPKK6XUBDGaLpdkoP/ifaXebadNRO4UkVwRya2urj6Tp1BKKTWMcb303xjzmDEmxxiT43K5xvOllVJqyhtNoB8H+k96nOLdppRSagIZTaBvA+aISKaI2IFbgA1jW5ZSSqnTNWKgG2PcwN3AJuAg8IwxZr+IPCgi6wBEZIWIlAI3A/8tIvvHsmillFKDjerCImPMRmDjSdvu73d7G56uGKWUUgESsNkWRaQaOHYGD40Favxcjj9oXadnotYFE7c2rev0TNS64OxqSzfGDDmqJGCBfqZEJHe4qSMDSes6PRO1Lpi4tWldp2ei1gVjV5uuWKSUUlOEBrpSSk0RkzHQHwt0AcPQuk7PRK0LJm5tWtfpmah1wRjVNun60JVSSg1tMrbQlVJKDUEDXSmlpohJE+gjLbIxjnWkisgbInJARPaLyFe92x8QkeMissv7dXWA6isSkb3eGnK926JF5FUROeL9d8Y41zS33/uyS0SaROSeQLxnIvKEiFSJyL5+24Z8f8TjEe/v3B4RWRaA2v5DRA55X/9vIhLl3Z4hIu393rvfjnNdw/7sROTb3vfssIhcOc51Pd2vpiIR2eXdPp7v13AZMfa/Z8aYCf8FWIGjwEzADuwGsgNUSyKwzHvbCeThWfjjAeAbE+C9KgJiT9r2EHCf9/Z9wE8D/LOsANID8Z4BFwHLgH0jvT/A1cBLgADnAlsDUNsawOa9/dN+tWX0Py4AdQ35s/P+X9gNOIBM7/9b63jVddL+/wTuD8D7NVxGjPnv2WRpofsW2TDGdAF9i2yMO2NMuTFmh/d2M575bc5ofvhxdB3wR+/tPwLXB64UPgIcNcacyVXCZ80Y8zZQd9Lm4d6f64A/GY8PgCgRSRzP2owxrxjPfEoAHxCAKTaGec+Gcx2w3hjTaYwpBPLx/P8d17pERICPA38Zi9c+lVNkxJj/nk2WQPfbIhv+JCIZwFJgq3fT3d4/mZ4Y726NfgzwiohsF5E7vdvijTHl3tsVQHxgSgM8s3X2/082Ed6z4d6fifZ79xk8Lbk+mSKyU0TeEpHVAahnqJ/dRHnPVgOVxpgj/baN+/t1UkaM+e/ZZAn0CUdEwoG/AvcYY5qA3wCzgCVAOZ4/9wLhQmPMMjxrwN4lIhf132k8f+MFZKyqeKZfXgc86900Ud4zn0C+P6ciIt8B3MCfvZvKgTRjzFLgXuApEYkYx5Im3M/uJLcysOEw7u/XEBnhM1a/Z5Ml0CfUIhsiEoTnB/VnY8zzAMaYSmNMjzGmF/gdY/Rn5kiMMce9/1YBf/PWUdn3J5z336pA1IbnQ2aHMabSW+OEeM8Y/v2ZEL93InIHcA3wSW8Q4O3SqPXe3o6nrzprvGo6xc8u4O+ZiNiAG4Gn+7aN9/s1VEYwDr9nkyXQJ8wiG96+ud8DB40xD/fb3r/P6wZg38mPHYfawkTE2Xcbzwm1fXjeq9u9h90OvDDetXkNaDVNhPfMa7j3ZwPwae8ohHOBxn5/Mo8LEbkK+BawzhjT1m+7S0Ss3tszgTlAwTjWNdzPbgNwi4g4RCTTW9eH41WX1+XAIWNMad+G8Xy/hssIxuP3bDzO+vrjC8+Z4Dw8n6zfCWAdF+L5U2kPsMv7dTXwP8Be7/YNQGIAapuJZ4TBbmB/3/sExACbgSPAa0B0AGoLA2qByH7bxv09w/OBUg504+mr/Oxw7w+eUQePen/n9gI5AagtH0//at/v2m+9x97k/RnvAnYA145zXcP+7IDveN+zw8Da8azLu/1J4IsnHTue79dwGTHmv2d66b9SSk0Rk6XLRSml1Ag00JVSaorQQFdKqSlCA10ppaYIDXSllJoiNNCVUmqK0EBXSqkp4v8DsTwJfeYxw5IAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train MSE: 0.1349\n",
      "Test MSE: 0.1159\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch_geometric_temporal.nn.recurrent import GConvLSTM\n",
    "\n",
    "from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader\n",
    "from torch_geometric_temporal.signal import temporal_signal_split\n",
    "\n",
    "loader = ChickenpoxDatasetLoader()\n",
    "\n",
    "lags = 10\n",
    "stride = 1\n",
    "epochs = 200\n",
    "\n",
    "dataset = loader.get_dataset(lags)\n",
    "\n",
    "train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)\n",
    "\n",
    "### MODEL DEFINITION\n",
    "class RecurrentGCN(nn.Module):\n",
    "    def __init__(self, node_features, window_length):\n",
    "        super(RecurrentGCN, self).__init__()\n",
    "        \n",
    "        self.node_features = node_features\n",
    "        self.n = window_length\n",
    "        self.hidden_state_size = 64\n",
    "        assert self.hidden_state_size % 2 == 0\n",
    "                \n",
    "        self.recurrent = GConvLSTM(node_features, self.hidden_state_size, 1)\n",
    "        self.linear1 = nn.Sequential(\n",
    "            nn.Linear(self.hidden_state_size, self.hidden_state_size//2), \n",
    "            nn.Dropout(p=0.2),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.out = nn.Sequential(\n",
    "            nn.Linear(self.hidden_state_size//2, 1),\n",
    "            nn.Flatten(start_dim=0, end_dim=-1)\n",
    "        )\n",
    "\n",
    "    def forward(self, window, h, c):        \n",
    "        edge_index = window.edge_index\n",
    "        edge_weight = None\n",
    "    \n",
    "        H, C = [], []\n",
    "        \n",
    "        for t in range(self.n):\n",
    "            x = window.x[:,t].unsqueeze(0).T\n",
    "            h, c = self.recurrent(x, edge_index, edge_weight, h, c)\n",
    "            \n",
    "            H.append(h.detach())\n",
    "            C.append(c.detach())\n",
    "            \n",
    "        x = self.linear1(h)\n",
    "        pred = self.out(x)\n",
    "\n",
    "        return pred, H, C\n",
    "    \n",
    "model = RecurrentGCN(node_features=1, window_length=10)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "\n",
    "### TRAIN\n",
    "model.train()\n",
    "\n",
    "loss_history = []\n",
    "for _ in tqdm(range(epochs)):\n",
    "    h, c = None, None\n",
    "    total_loss = 0\n",
    "    for i, window in enumerate(train_dataset):\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        y_pred, H, C = model(window, h, c)\n",
    "\n",
    "        h = H[stride-1]\n",
    "        c = C[stride-1]\n",
    "        \n",
    "        assert y_pred.shape == window.y.shape\n",
    "        loss = torch.mean((y_pred - window.y)**2)\n",
    "        total_loss += loss.item()\n",
    "        \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    total_loss /= i+1\n",
    "    loss_history.append(total_loss)\n",
    "\n",
    "### TEST \n",
    "model.eval()\n",
    "loss = 0\n",
    "with torch.no_grad():\n",
    "    h, c = None, None\n",
    "    for i, window in enumerate(test_dataset):\n",
    "        y_pred, H, C = model(window, h, c)\n",
    "        h = H[stride-1]\n",
    "        c = C[stride-1]\n",
    "        \n",
    "        assert y_pred.shape == window.y.shape\n",
    "        loss += torch.mean((y_pred - window.y)**2)\n",
    "    loss /= i+1\n",
    "\n",
    "### RESULTS PLOT\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "x_ticks = np.arange(1, epochs+1)\n",
    "ax.plot(x_ticks, loss_history)\n",
    "\n",
    "# figure labels\n",
    "ax.set_title('Loss over time', fontweight='bold')\n",
    "ax.set_xlabel('Epochs', fontweight='bold')\n",
    "ax.set_ylabel('Mean Squared Error', fontweight='bold')\n",
    "plt.show()\n",
    "    \n",
    "print(\"Train MSE: {:.4f}\".format(total_loss))\n",
    "print(\"Test MSE: {:.4f}\".format(loss))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b3be8ce",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
